{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from fastai.vision import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ImageDataBunch;\n",
       "\n",
       "Train: LabelList (12396 items)\n",
       "x: ImageList\n",
       "Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28)\n",
       "y: CategoryList\n",
       "3,3,3,3,3\n",
       "Path: /home/user/.fastai/data/mnist_sample;\n",
       "\n",
       "Valid: LabelList (2038 items)\n",
       "x: ImageList\n",
       "Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28),Image (1, 28, 28)\n",
       "y: CategoryList\n",
       "3,3,3,3,3\n",
       "Path: /home/user/.fastai/data/mnist_sample;\n",
       "\n",
       "Test: None"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "DATA = untar_data(URLs.MNIST_SAMPLE)\n",
    "src = (ImageList.from_folder(DATA, convert_mode='L')\n",
    "                .split_by_folder(valid='valid')\n",
    "                .label_from_folder())\n",
    "data = (src.databunch(bs=8, num_workers=0)\n",
    "           .normalize((tensor(0.128), tensor(0.305))))\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "def get_mdl(actn:Callable):\n",
    "    layers = [conv2d(1, 16, stride=2), actn(), conv2d(16, 32, stride=2), actn(), AdaptiveConcatPool2d(1), Flatten(), nn.Linear(64, data.c)]\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "  (1): MishCuda()\n",
       "  (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "  (3): MishCuda()\n",
       "  (4): AdaptiveConcatPool2d(\n",
       "    (ap): AdaptiveAvgPool2d(output_size=1)\n",
       "    (mp): AdaptiveMaxPool2d(output_size=1)\n",
       "  )\n",
       "  (5): Flatten()\n",
       "  (6): Linear(in_features=64, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from mish_cuda import *\n",
    "mdl_mish = get_mdl(MishCuda)\n",
    "mdl_mish"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "lrn_mish = Learner(data, mdl_mish, metrics=[accuracy])\n",
    "cbs = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.184366</td>\n",
       "      <td>0.161153</td>\n",
       "      <td>0.951914</td>\n",
       "      <td>00:06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.079940</td>\n",
       "      <td>0.074344</td>\n",
       "      <td>0.973503</td>\n",
       "      <td>00:06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.033569</td>\n",
       "      <td>0.042023</td>\n",
       "      <td>0.986752</td>\n",
       "      <td>00:06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.045685</td>\n",
       "      <td>0.036185</td>\n",
       "      <td>0.986752</td>\n",
       "      <td>00:06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.037368</td>\n",
       "      <td>0.037822</td>\n",
       "      <td>0.985770</td>\n",
       "      <td>00:06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.017288</td>\n",
       "      <td>0.028351</td>\n",
       "      <td>0.989205</td>\n",
       "      <td>00:06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.015544</td>\n",
       "      <td>0.025071</td>\n",
       "      <td>0.992640</td>\n",
       "      <td>00:06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.008614</td>\n",
       "      <td>0.034780</td>\n",
       "      <td>0.987733</td>\n",
       "      <td>00:06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.011460</td>\n",
       "      <td>0.023918</td>\n",
       "      <td>0.994112</td>\n",
       "      <td>00:06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.006774</td>\n",
       "      <td>0.024139</td>\n",
       "      <td>0.994112</td>\n",
       "      <td>00:06</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "lrn_mish.fit_one_cycle(10, 1e-3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-fastai]",
   "language": "python",
   "name": "conda-env-.conda-fastai-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
